{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "42_6VTj7l_xS"
      },
      "source": [
        "# Tutorial: Exporting StableHLO from PyTorch\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][pytorch-tutorial-colab]\n",
        "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][pytorch-tutorial-kaggle]\n",
        "\n",
        "PyTorch is a popular library for building deep learning models.\n",
        "In this tutorial, you will learn to export a PyTorch model to StableHLO, and then directly to TensorFlow SavedModel.\n",
        "\n",
        "## Tutorial Setup\n",
        "\n",
        "### Install required dependencies\n",
        "\n",
        "We use `torch` and `torchvision` to get a [ResNet18 model](https://pytorch.org/vision/stable/models/resnet.html), and `torch_xla` to export it to StableHLO.\n",
        "We also need to install `tensorflow` to work with SavedModel, and recommend using `tensorflow-cpu` or `tf-nightly` for this tutorial.\n",
        "\n",
        "[pytorch-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb\n",
        "[pytorch-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-GTAijyd3rOf"
      },
      "outputs": [],
      "source": [
        "!pip install torch_xla==2.5.0 torch==2.5.0 torchvision==0.20.0 tensorflow-cpu"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V_AtPpV30Bt8"
      },
      "source": [
        "## Export PyTorch model to StableHLO\n",
        "\n",
        "The general set of steps for exporting a PyTorch model to StableHLO is:\n",
        "1. Use [PyTorch's `torch.export` API](https://pytorch.org/docs/2.5/export.html#torch-export) to generate an exported FX graph (i.e., `ExportedProgram`)\n",
        "2. Use [PyTorch/XLA's `torch_xla.stablehlo` API](https://pytorch.org/xla/master/features/stablehlo.html) to convert the `ExportedProgram` to StableHLO\n",
        "\n",
        "### Export model to FX graph using `torch.export`\n",
        "\n",
        "This step uses vanilla PyTorch APIs to export a `resnet18` model from `torchvision`.\n",
        "Sample inputs are required for graph tracing; we use a `tensor<4x3x224x224xf32>` in this case."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "GhIpxnx5fuxy"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torchvision\n",
        "from torch.export import export\n",
        "\n",
        "resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
        "sample_input = (torch.randn(4, 3, 224, 224), )\n",
        "exported = export(resnet18, sample_input)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zuMAr3WO1PBk"
      },
      "source": [
        "### Export FX graph to StableHLO using `torch_xla.stablehlo`\n",
        "\n",
        "Once we have an exported FX graph, we can convert it to StableHLO using `exported_program_to_stablehlo` in the `torch_xla.stablehlo` module.\n",
        "\n",
        "We can then look at the exported StableHLO program with `get_stablehlo_text`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xiP7psIQgUc-",
        "outputId": "0de7d5bc-01a1-4e96-b2ce-19faa9965459"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:root:Defaulting to PJRT_DEVICE=CPU\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "module @IrToHlo.484 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\n",
            "  func.func @main(%arg0: tensor<1000xf32>, %arg1: tensor<1000x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512xf32>, %arg4: tensor<512xf32>, %arg5: tensor<512xf32>, %arg6: tensor<512x256x1x1xf32>, %arg7: tensor<256xf32>, %arg8: tensor<256xf32>, %arg9: tensor<256xf32>, %arg10: tensor<256xf32>, %arg11: tensor<256x128x1x1xf32>, %arg12: tensor<128xf32>, %arg13: tensor<128xf32>, %arg14: tensor<128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x64x1x1xf32>, %arg17: tensor<64xf32>, %arg18: tensor<64xf32>, %arg19: tensor<64xf32>, %arg20: tensor<64xf32>, %arg21: tensor<64x3x7x7xf32>, %arg22: tensor<4x3x224x224xf32>, %arg23: tensor<64xf32>, %arg24: tensor<64xf32>, %arg25: tensor<64xf32>, %arg26: tensor<64xf32>, %arg27: tensor<64x64x3x3xf32>, %arg28: tensor<64xf32>, %arg29: tensor<64xf32>, %arg30: tensor<64xf32>, %arg31: tensor<64xf32>, %arg32: tensor<64x64x3x3xf32>, %arg33: tensor<64xf32>, %arg34: tensor<64xf32>, %arg35: tensor<64xf32>, %arg36: tensor<64xf32>, %arg37: tensor<64x64x3x3xf32>, %arg38: tensor<64xf32>, %arg39: tensor<64xf32>, %arg40: tensor<64xf32>, %arg41: tensor<64xf32>, %arg42: tensor<64x64x3x3xf32>, %arg43: tensor<128xf32>, %arg44: tensor<128xf32>, %arg45: tensor<128xf32>, %arg46: tensor<128xf32>, %arg47: tensor<128x128x3x3xf32>, %arg48: tensor<128xf32>, %arg49: tensor<128xf32>, %arg50: tensor<128xf32>, %arg51: tensor<128xf32>, %arg52: tensor<128x64x3x3xf32>, %arg53: tensor<128xf32>, %arg54: tensor<128xf32>, %arg55: tensor<128xf32>, %arg56: tensor<128xf32>, %arg57: tensor<128x128x3x3xf32>, %arg58: tensor<128xf32>, %arg59: tensor<128xf32>, %arg60: tensor<128xf32>, %arg61: tensor<128xf32>, %arg62: tensor<128x128x3x3xf32>, %arg63: tensor<256xf32>, %arg64: tensor<256xf32>, %arg65: tensor<256xf32>, %arg66: tensor<256xf32>, %arg67: tensor<256x256x3x3xf32>, %arg68: tensor<256xf32>, %arg69: tensor<256xf32>, %arg70: tensor<256xf32>, %arg71: tensor<256xf32>, %arg72: tensor<256x128x3x3xf32>, %arg73: tensor<256xf32>, %arg74: tensor<256xf32>, %arg75: tensor<256xf32>, %arg76: tensor<256xf32>, %arg77: tensor<256x256x3x3xf32>, %arg78: tensor<256xf32>, %arg79: tensor<256xf32>, %arg80: tensor<256xf32>, %arg81: tensor<256xf32>, %arg82: tensor<256x256x3x3xf32>, %arg83: tensor<512xf32>, %arg84: tensor<512xf32>, %arg85: tensor<512xf32>, %arg86: tensor<512xf32>, %arg87: tensor<512x512x3x3xf32>, %arg88: tensor<512xf32>, %arg89: tensor<512xf32>, %arg90: tensor<512xf32>, %arg91: tensor<512xf32>, %arg92: tensor<512x256x3x3xf32>, %arg93: tensor<512xf32>, %arg94: tensor<512xf32>, %arg95: tensor<512xf32>, %arg96: tensor<512xf32>, %arg97: tensor<512x512x3x3xf32>, %arg98: tensor<512xf32>, %arg99: tensor<512xf32>, %arg100: tensor<512xf32>, %arg101: tensor<512xf32>, %arg102: tensor<512x512x3x3xf32>) -> tensor<4x1000xf32> {\n",
            "    %cst = stablehlo.constant dense<0.0204081628> : tensor<4x512xf32>\n",
            "    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<4x512x7x7xf32>\n",
            "    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<4x256x14x14xf32>\n",
            "    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<4x128x28x28xf32>\n",
            "    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x56x56xf32>\n",
            "    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x112x112xf32>\n",
            "    %cst_5 = stablehlo.constant dense<0xFF800000> : tensor<f32>\n",
            "    %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
            "    %0 = stablehlo.convolution(%arg22, %arg21) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 2], pad = [[3, 3], [3, 3]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<4x3x224x224xf32>, tensor<64x3x7x7xf32>) -> tensor<4x64x112x112xf32>\n",
            "    %output, %batch_mean, %batch_var = \"stablehlo.ba \n",
            "...\n"
          ]
        }
      ],
      "source": [
        "from torch_xla.stablehlo import exported_program_to_stablehlo\n",
        "\n",
        "stablehlo_program = exported_program_to_stablehlo(exported)\n",
        "print(stablehlo_program.get_stablehlo_text('forward')[0:4000],\"\\n...\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2Ujt2OjtpERw"
      },
      "source": [
        "_Tip:_\n",
        "\n",
        "_Dynamic batch dimensions can be specified as a part of the initial `torch.export` step._\n",
        "\n",
        "_`torch_xla`'s support for exporting dynamic models is limited. For these cases we recommend using [`torch_xla2`](https://github.com/pytorch/xla/tree/master/experimental/torch_xla2), which leverages JAX for lowering to StableHLO and has high opset coverage with much broader support for exported programs with dynamic shapes._"
      ]
    },
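    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the tip above, the sketch below marks the batch dimension as dynamic during `torch.export`. It relies on the `Dim` and `dynamic_shapes` features of the `torch.export` API; the `TinyNet` module and the `dynamic_exported` name are hypothetical and exist only to keep the example small. Whether the resulting program lowers cleanly to StableHLO depends on the lowering path you choose, as noted above.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torch.export import Dim, export\n",
        "\n",
        "\n",
        "# Hypothetical toy model, used only to keep the sketch small.\n",
        "class TinyNet(torch.nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.linear = torch.nn.Linear(8, 2)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.linear(x)\n",
        "\n",
        "\n",
        "# Mark dimension 0 of the input `x` as a symbolic batch dimension.\n",
        "batch = Dim('batch')\n",
        "dynamic_exported = export(\n",
        "    TinyNet().eval(),\n",
        "    (torch.randn(4, 8),),\n",
        "    dynamic_shapes={'x': {0: batch}},\n",
        ")\n",
        "\n",
        "# The exported program records a range constraint for the symbolic dimension.\n",
        "print(dynamic_exported.range_constraints)\n",
        "```\n",
        "\n",
        "The keys of `dynamic_shapes` follow the argument names of `forward`, so a model with a differently named input needs a matching key."
      ]
    },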
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ySTEYXzG1fU6"
      },
      "source": [
        "### Save and reload StableHLO\n",
        "\n",
        "`StableHLOGraphModule` has methods to `save` and `load` StableHLO artifacts.\n",
        "Saving writes StableHLO portable bytecode artifacts, which carry StableHLO's forward and backward compatibility guarantees."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jeVTUs7jh8lk",
        "outputId": "4c248260-7442-495c-932b-a618e9eb2c67"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "constants  data  functions\n",
            "forward.bytecode  forward.meta\tforward.mlir\n"
          ]
        }
      ],
      "source": [
        "from torch_xla.stablehlo import StableHLOGraphModule\n",
        "\n",
        "# Save to tmp\n",
        "stablehlo_program.save('/tmp/stablehlo_dir')\n",
        "!ls /tmp/stablehlo_dir\n",
        "!ls /tmp/stablehlo_dir/functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor([[-2.3258, -0.9606, -0.9439,  ...,  0.3519,  0.6261,  2.3971],\n",
            "        [ 1.6479, -0.0268,  1.0511,  ..., -1.2512,  2.2042,  1.8865],\n",
            "        [ 0.1756, -0.3658, -0.0651,  ...,  0.0661,  2.1358,  0.5009],\n",
            "        [-1.6709, -0.7363, -2.0963,  ..., -1.3716,  0.3321, -0.9199]],\n",
            "       device='xla:0')\n"
          ]
        }
      ],
      "source": [
        "# Reload and execute - Stable serialization, forward / backward compatible.\n",
        "reloaded = StableHLOGraphModule.load('/tmp/stablehlo_dir')\n",
        "print(reloaded(sample_input[0]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "_Note: You can also use convenience wrappers like `save_torch_model_as_stablehlo` to export and save. Learn more in the [PyTorch/XLA documentation on exporting to StableHLO](https://pytorch.org/xla/master/features/stablehlo.html#other-common-wrappers)._"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NsJyDxxnjd4B"
      },
      "source": [
        "## Export to TensorFlow SavedModel\n",
        "\n",
        "Exporting a StableHLO model to TensorFlow SavedModel is common for interoperability with existing compilation pipelines and TensorFlow tooling, or for serving via [TensorFlow Serving](https://github.com/tensorflow/serving).\n",
        "\n",
        "PyTorch/XLA's `torch_xla.tf_saved_model_integration` module makes it easy to pack StableHLO into a SavedModel, which can be loaded back and executed.\n",
        "\n",
        "### Export to SavedModel with `torch_xla.tf_saved_model_integration`\n",
        "\n",
        "We use the `save_torch_module_as_tf_saved_model` function for this conversion, which uses the `torch.export` and `torch_xla.stablehlo.exported_program_to_stablehlo` functions under the hood.\n",
        "\n",
        "The input to the API is a PyTorch model, and we use the same `resnet18` from the previous examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1g24aShhjG6e",
        "outputId": "382d14f5-7b7b-4297-8af4-0999b7ba4f3f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "assets\tfingerprint.pb\tsaved_model.pb\tvariables\n"
          ]
        }
      ],
      "source": [
        "from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model\n",
        "\n",
        "save_torch_module_as_tf_saved_model(\n",
        "    resnet18,         # original pytorch torch.nn.Module\n",
        "    sample_input,     # sample inputs used to trace\n",
        "    '/tmp/resnet_tf'  # directory for tf.saved_model\n",
        ")\n",
        "\n",
        "!ls /tmp/resnet_tf/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "shmmhYP76eX2"
      },
      "source": [
        "### Reload and call the SavedModel\n",
        "\n",
        "Now we can load the SavedModel and run it on the `sample_input` from the earlier example.\n",
        "\n",
        "_Note: The restored model does *not* require PyTorch or PyTorch/XLA to run, just XLA._"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2TZMQyJj6fHy",
        "outputId": "c2f97232-a623-4146-d369-ef75c2033136"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
            "I0000 00:00:1730760467.760638    8492 service.cc:148] XLA service 0x7ede002016e0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n",
            "I0000 00:00:1730760467.760777    8492 service.cc:156]   StreamExecutor device (0): Host, Default Version\n",
            "I0000 00:00:1730760468.613723    8492 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[<tf.Tensor: shape=(4, 1000), dtype=float32, numpy=\n",
            "array([[-2.3257551 , -0.96061766, -0.9439326 , ...,  0.35189423,\n",
            "         0.62605226,  2.3971176 ],\n",
            "       [ 1.6479174 , -0.02676968,  1.0511047 , ..., -1.2511721 ,\n",
            "         2.2041895 ,  1.8865337 ],\n",
            "       [ 0.17559683, -0.365776  , -0.06507193, ...,  0.06606296,\n",
            "         2.135755  ,  0.500913  ],\n",
            "       [-1.6709077 , -0.7362997 , -2.0962732 , ..., -1.3716122 ,\n",
            "         0.33205754, -0.91991633]], dtype=float32)>]\n"
          ]
        }
      ],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "loaded_m = tf.saved_model.load('/tmp/resnet_tf')\n",
        "print(loaded_m.f(tf.constant(sample_input[0].numpy())))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fSk7HFVk8CqR"
      },
      "source": [
        "## Troubleshooting\n",
        "\n",
        "### Version mismatch\n",
        "\n",
        "Ensure that the installed versions of PyTorch/XLA (`torch_xla`) and PyTorch (`torch`) match, as in the `pip install` command at the top of this tutorial. A version mismatch can cause import errors as well as runtime issues.\n",
        "\n",
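        "As a quick check, the installed versions can be compared with the standard library's `importlib.metadata`. This is a minimal sketch, not an official PyTorch/XLA utility: the helper names `same_release` and `installed_version` are hypothetical, and treating a shared `major.minor` prefix as a match is an assumption based on the pinned versions in the install command.\n",
        "\n",
        "```python\n",
        "from importlib import metadata\n",
        "\n",
        "\n",
        "def same_release(a, b):\n",
        "    # Compare only major.minor, dropping local build tags such as '+cpu'.\n",
        "    core_a = a.split('+')[0].split('.')[:2]\n",
        "    core_b = b.split('+')[0].split('.')[:2]\n",
        "    return core_a == core_b\n",
        "\n",
        "\n",
        "def installed_version(package):\n",
        "    # Return the installed version string, or None if the package is missing.\n",
        "    try:\n",
        "        return metadata.version(package)\n",
        "    except metadata.PackageNotFoundError:\n",
        "        return None\n",
        "\n",
        "\n",
        "torch_version = installed_version('torch')\n",
        "xla_version = installed_version('torch_xla')\n",
        "if torch_version is None or xla_version is None:\n",
        "    print('torch and/or torch_xla was not found in this environment')\n",
        "elif same_release(torch_version, xla_version):\n",
        "    print(f'OK: torch {torch_version} matches torch_xla {xla_version}')\n",
        "else:\n",
        "    print(f'Mismatch: torch {torch_version} vs torch_xla {xla_version}')\n",
        "```\n",
        "\n",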
        "### Export bugs\n",
        "\n",
        "If your program fails to export due to a bug in the PyTorch/XLA bridge, open an issue on GitHub with a reproducible example:\n",
        "\n",
        "- Issues in `torch.export`: Report these in the upstream [pytorch/pytorch](https://github.com/pytorch/pytorch) repository\n",
        "- Issues in `torch_xla.stablehlo`: Open a ticket in the [pytorch/xla](https://github.com/pytorch/xla) repository"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
